# -*-Python-*-

import mesh_tensorflow.transformer
import mesh_tensorflow.transformer.adaptive_softmax

# Bind the loss function of Unitransformer, within the "decoder" gin scope
# only, to the adaptive softmax loss.
decoder/Unitransformer.loss_fn = @adaptive_softmax.adaptive_softmax_loss_fn

# Bind the `cls` parameter of get_vocab_embedding_cls to the adaptive softmax
# vocab embedding class.
transformer.get_vocab_embedding_cls.cls = @adaptive_softmax.AdaptiveSoftmaxVocabEmbedding

# Required for the adaptive softmax to work.
# Note that this binding is unscoped, unlike the "decoder/"-scoped loss_fn
# binding above.
Unitransformer.shared_embedding_and_softmax_weights = True

# These parameters will be silently ignored by the adaptive softmax.
# NOTE(review): presumably they are pinned to these neutral values so the
# config does not suggest they have any effect -- confirm.
Unitransformer.label_smoothing = 0.0
Unitransformer.loss_on_targets_only = False

# The clusters are not set in this file; you must bind them yourself, in a form like this:
# adaptive_softmax.AdaptiveSoftmaxVocabEmbedding.clusters = [
#         {
#           "token_count": 50000,
#           "embedding_size": 1024,
#         }, {
#           "token_count": 100000,
#           "embedding_size": 256,
#           "length_projection_factor": 0.25,
#         }, {
#           "token_count": 350000,
#           "embedding_size": 64,
#           "length_projection_factor": 0.125,
#         },
#        ]
